{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9f85f826",
   "metadata": {},
   "source": [
    "It is useful to cluster the 1B datasets to around 262k - 1M clusters for IVF indexing with Faiss.\n",
    "However, it is not feasible to do the clustering within the allocated time for indexing. \n",
    "\n",
    "Therefore, here we evaluate other options to break down the clustering cost, while getting the same number of clusters.\n",
    "The model that we use is: Deep1M (1M database vectors), 4096 clusters (which conveniently breaks down to 2^6 * 2^6)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ongoing-first",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import faiss\n",
    "from faiss.contrib import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "finnish-giant",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = datasets.DatasetDeep1B(10**6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "satisfied-adoption",
   "metadata": {},
   "outputs": [],
   "source": [
    "xt = ds.get_train(10**5)\n",
    "d = ds.d\n",
    "xb = ds.get_database()\n",
    "xq = ds.get_queries()\n",
    "gt = ds.get_groundtruth()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "genetic-sleep",
   "metadata": {},
   "outputs": [],
   "source": [
    "sqrt_nlist = 64\n",
    "nlist = sqrt_nlist**2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "indoor-client",
   "metadata": {},
   "source": [
    "# Flat quantizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a18ca10",
   "metadata": {},
   "source": [
    "Flat quantizer is what we would like to apprach, but it probably too costly. \n",
    "We include it here as a topline.\n",
    "The measure we use is recall of nearest neighbor vs. number of computed distances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "romance-pacific",
   "metadata": {},
   "outputs": [],
   "source": [
    "quantizer = faiss.IndexFlatL2(d)\n",
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "noble-possession",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.431187283968"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.train(xt)\n",
    "index.add(xb)\n",
    "index.invlists.imbalance_factor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "described-chicago",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nprobe=1 1-recall @ 1: 0.3745 dis/q=349.15\n",
      "nprobe=4 1-recall @ 1: 0.6849 dis/q=1344.67\n",
      "nprobe=16 1-recall @ 1: 0.9004 dis/q=5040.35\n",
      "nprobe=64 1-recall @ 1: 0.9793 dis/q=18331.49\n"
     ]
    }
   ],
   "source": [
    "stats = faiss.cvar.indexIVF_stats\n",
    "for nprobe in 1, 4, 16, 64: \n",
    "    index.nprobe = nprobe \n",
    "    stats.reset()\n",
    "    D, I = index.search(xq, 100)\n",
    "    rank = 1\n",
    "    recall = (I[:, :rank] == gt[:, :1]).sum() / len(xq)\n",
    "    print(f\"nprobe={nprobe} 1-recall @ {rank}: {recall} dis/q={stats.ndis/len(xq):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "realistic-valve",
   "metadata": {},
   "source": [
    "# IMI quantizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0388876",
   "metadata": {},
   "source": [
    "The IMI quantizer is a cheap way of breaking down the dataset into buckets. It is a PQ2x6 and each PQ code ends in a separate bucket. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "amateur-earth",
   "metadata": {},
   "outputs": [],
   "source": [
    "quantizer = faiss.MultiIndexQuantizer(d, 2, int(np.log2(sqrt_nlist)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "unsigned-motorcycle",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)\n",
    "index.quantizer_trains_alone = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "organizational-impossible",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16.421237645312"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.train(xt)\n",
    "index.add(xb)\n",
    "index.invlists.imbalance_factor()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7be36ece",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nprobe=1 1-recall @ 1: 0.437 dis/q=3972.32\n",
      "nprobe=4 1-recall @ 1: 0.6948 dis/q=9210.20\n",
      "nprobe=16 1-recall @ 1: 0.8656 dis/q=19246.74\n",
      "nprobe=64 1-recall @ 1: 0.9613 dis/q=41114.89\n"
     ]
    }
   ],
   "source": [
    "stats = faiss.cvar.indexIVF_stats\n",
    "\n",
    "for nprobe in 1, 4, 16, 64: \n",
    "    index.nprobe = nprobe \n",
    "    stats.reset()\n",
    "\n",
    "    D, I = index.search(xq, 100)\n",
    "    rank = 1\n",
    "    recall = (I[:, :rank] == gt[:, :1]).sum() / len(xq)\n",
    "    print(f\"nprobe={nprobe} 1-recall @ {rank}: {recall} dis/q={stats.ndis/len(xq):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc986a53",
   "metadata": {},
   "source": [
    "So way less efficient than the flat quantizer, due to imbalanced inverted lists. TBH, the IMI quantizer usually sets a cap on the number of distances rather than fixing the number of visited buckets. "
   ]
  },
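  {
   "cell_type": "markdown",
   "id": "imbalance-note",
   "metadata": {},
   "source": [
    "A rough sketch of how the imbalance factor relates to the dis/q column. As far as we understand it, Faiss computes the factor from the list sizes $n_i$ as\n",
    "\n",
    "$$\\text{imbalance factor} = \\frac{\\text{nlist} \\cdot \\sum_i n_i^2}{\\left(\\sum_i n_i\\right)^2}$$\n",
    "\n",
    "which is 1 when all lists have the same size. Assuming (this is our assumption, not something measured here) that queries are distributed like the database vectors, a query falls into list $i$ with probability $n_i / N$ and then scans $n_i$ vectors, so at nprobe=1\n",
    "\n",
    "$$\\mathbb{E}[\\text{dis/q}] \\approx \\sum_i \\frac{n_i}{N} \\, n_i = \\text{imbalance factor} \\cdot \\frac{N}{\\text{nlist}}$$\n",
    "\n",
    "With $N = 10^6$ and nlist = 4096, $N / \\text{nlist} \\approx 244$. The flat quantizer's factor of 1.43 then predicts about 349 distances per query and the IMI factor of 16.42 about 4009, close to the measured 349.15 and 3972.32."
   ]
  },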
  {
   "cell_type": "markdown",
   "id": "south-differential",
   "metadata": {},
   "source": [
    "# Residual quantizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e5910d8",
   "metadata": {},
   "source": [
    "This is a 2-level additive quantizer where the first level is trained first, then the second. Since it is an additive quantizer, the top-k centroids can be retrieved efficiently with lookup tables. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "elect-vacation",
   "metadata": {},
   "outputs": [],
   "source": [
    "quantizer = faiss.ResidualCoarseQuantizer(d, 2, int(np.log2(sqrt_nlist)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "nervous-lesbian",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)\n",
    "index.quantizer_trains_alone = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "ae530558",
   "metadata": {},
   "outputs": [],
   "source": [
    "index.train(xt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "ceaa6077",
   "metadata": {},
   "outputs": [],
   "source": [
    "quantizer.set_beam_factor(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "3eb25d40",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3.604173447168"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.add(xb)\n",
    "index.invlists.imbalance_factor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "af3a02de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nprobe=1 1-recall @ 1: 0.3079 dis/q=878.77\n",
      "nprobe=4 1-recall @ 1: 0.6091 dis/q=3017.90\n",
      "nprobe=16 1-recall @ 1: 0.8608 dis/q=9996.18\n",
      "nprobe=64 1-recall @ 1: 0.9685 dis/q=31318.18\n"
     ]
    }
   ],
   "source": [
    "stats = faiss.cvar.indexIVF_stats\n",
    "\n",
    "for nprobe in 1, 4, 16, 64: \n",
    "    index.nprobe = nprobe \n",
    "    stats.reset()\n",
    "\n",
    "    D, I = index.search(xq, 100)\n",
    "    rank = 1\n",
    "    recall = (I[:, :rank] == gt[:, :1]).sum() / len(xq)\n",
    "    print(f\"nprobe={nprobe} 1-recall @ {rank}: {recall} dis/q={stats.ndis/len(xq):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9db020c",
   "metadata": {},
   "source": [
    "Unfortunately still not very good.  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a0514ef",
   "metadata": {},
   "source": [
    "# 2-level tree quantizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "adfc3b23",
   "metadata": {},
   "source": [
    "This is a suggestion by Harsha: just cluster to 64 centroids at the first level and train separate clusterings within each bucket."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "4f86ff7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1st level quantizer "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "8157ef04",
   "metadata": {},
   "outputs": [],
   "source": [
    "km = faiss.Kmeans(d, sqrt_nlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "29b154ce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9879.4462890625"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "km.train(xt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "27a355a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "centroids1 = km.centroids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "6083d36f",
   "metadata": {},
   "outputs": [],
   "source": [
    "xt2 = ds.get_train(500_000)\n",
    "\n",
    "_, assign1 = km.assign(xt2)\n",
    "bc = np.bincount(assign1)\n",
    "o = assign1.argsort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "32e64dfb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "63\r"
     ]
    }
   ],
   "source": [
    "i0 = 0\n",
    "c2 = []\n",
    "for c1 in range(sqrt_nlist): \n",
    "    print(c1, end=\"\\r\", flush=True)\n",
    "    i1 = i0 + bc[c1]\n",
    "    subset = o[i0:i1]\n",
    "    assert np.all(assign1[subset] == c1)\n",
    "    km = faiss.Kmeans(d, sqrt_nlist)\n",
    "    xtsub = xt2[subset]\n",
    "    km.train(xtsub)\n",
    "    c2.append(km.centroids)\n",
    "    i0 = i1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "704c495a",
   "metadata": {},
   "source": [
    "Then we just stack the centroids together and forget about the first level clustering. \n",
    "In reality with 262k-1M clusters, we'll train a HNSW or NSG index on top.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "b41aeeae",
   "metadata": {},
   "outputs": [],
   "source": [
    "centroids12 = np.vstack(c2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "7041f966",
   "metadata": {},
   "outputs": [],
   "source": [
    "quantizer = faiss.IndexFlatL2(d)\n",
    "quantizer.add(centroids12)\n",
    "index = faiss.IndexIVFFlat(quantizer, d, nlist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "1bf4175d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.200742457344"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "index.add(xb)\n",
    "index.invlists.imbalance_factor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "id": "6d2acf15",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "nprobe=1 1-recall @ 1: 0.3774 dis/q=291.20\n",
      "nprobe=4 1-recall @ 1: 0.6847 dis/q=1153.03\n",
      "nprobe=16 1-recall @ 1: 0.8995 dis/q=4459.66\n",
      "nprobe=64 1-recall @ 1: 0.9825 dis/q=16942.70\n"
     ]
    }
   ],
   "source": [
    "stats = faiss.cvar.indexIVF_stats\n",
    "for nprobe in 1, 4, 16, 64: \n",
    "    index.nprobe = nprobe \n",
    "    stats.reset()\n",
    "    D, I = index.search(xq, 100)\n",
    "    rank = 1\n",
    "    recall = (I[:, :rank] == gt[:, :1]).sum() / len(xq)\n",
    "    print(f\"nprobe={nprobe} 1-recall @ {rank}: {recall} dis/q={stats.ndis/len(xq):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35c0a565",
   "metadata": {},
   "source": [
    "Turns out this is very good: same level of accuracy as the flat topline!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4f1c3a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
